from gym.spaces import Discrete, Box, MultiDiscrete, Dict
import numpy as np
import gym
import random, math, copy
import stack_mdp.utils as utils
from collections import OrderedDict


class BaseEnvMatrixDesignGame(gym.Env):
    """One-shot matrix game in which the leader designs a diagonal bonus.

    The leader ("game_designer") picks an integer from ``Discrete(10)``.
    When both followers play action 0 the first follower's payoff grows by
    ``action / game.norm_factor``; when both play action 1 the second
    follower's payoff grows by the same amount. The leader is rewarded with
    1 when the two followers pick different actions and with 0 otherwise.

    The followers' joint action is not chosen here: whoever drives the
    environment must assign ``followers_actions`` before calling ``step``
    (``get_followers_action`` is True while that assignment is pending).
    """

    def __init__(self, game, logger=None):
        """Store the game and build the leader's and followers' spaces.

        Args:
            game: Game object exposing ``list_of_agents``, ``action_space``,
                ``payoff`` and ``norm_factor``.
            logger: Optional object with a ``record(key, value)`` method.
        """
        self.game = game
        self.logger = logger
        self.leader = "game_designer"
        self.followers_list = game.list_of_agents
        self.num_followers = len(self.followers_list)

        # The second entry is a flag telling whether the action taken on
        # this observation has an effect.
        self.observation_space = Dict({'base_environment': MultiDiscrete([1, 2])})

        # The action is the additional reward placed on the diagonal.
        self.action_space = Discrete(10)

        # Followers are handled one by one rather than as a MARL environment.
        self.followers_observation_space = {}
        self.followers_action_space = {}
        for follower in self.followers_list:
            self.followers_observation_space[follower] = range(1)
            self.followers_action_space[follower] = range(self.game.action_space(follower))

    def reset(self):
        """Start a new episode and return the dummy observation."""
        self.get_followers_action = True
        self.followers_obs = dict.fromkeys(self.followers_list, 0)
        return OrderedDict({"base_environment": [0, 0]})

    def step(self, action):
        """Play the single round of the game.

        Args:
            action: The leader's diagonal bonus.

        Returns:
            ``(observation, reward, done, info)`` where ``done`` is always
            True and ``info["utilities"]`` maps every agent (leader
            included) to its utility.
        """
        self.get_followers_action = False

        played = self.followers_actions
        utilities = self.compute_utilities(action, played)

        first, second = self.followers_list[0], self.followers_list[1]
        if played[first] == played[second]:
            reward = 0
        else:
            reward = 1
        utilities[self.leader] = reward

        dummy_obs = OrderedDict({"base_environment": [0, 0]})
        return dummy_obs, reward, True, {"utilities": utilities}

    def compute_utilities(self, leader_action, followers_actions):
        """Return each game agent's payoff including the leader's bonus.

        Args:
            leader_action: Bonus chosen by the leader.
            followers_actions: Mapping follower name -> action; not mutated.

        Returns:
            Dict mapping every agent in ``game.list_of_agents`` to its
            utility.
        """
        joint_action = followers_actions.copy()
        joint_action[self.leader] = leader_action

        agents = self.game.list_of_agents
        utilities = {name: self.game.payoff(joint_action, name) for name in agents}

        # Diagonal cell (k, k) rewards the k-th agent with the scaled bonus.
        for diagonal in (0, 1):
            if joint_action[agents[0]] == diagonal and joint_action[agents[1]] == diagonal:
                utilities[agents[diagonal]] += leader_action / self.game.norm_factor

        return utilities

    def log_info(self, info):
        """Record the supervisor's action found in ``info`` with the logger."""
        self.logger.record("supervisor_action", info["supervisor_action"])



class BaseEnvSimpleMatrixGame(gym.Env):
    """One-shot matrix game between a leader and a list of followers.

    In the deterministic variant the leader picks one of its matrix actions.
    In the randomized variant the leader outputs one non-negative weight per
    action; the weights are turned into probabilities by
    ``utils.weights_to_probs`` and every agent receives its expected payoff.

    The followers' joint action must be assigned to ``followers_actions``
    by whoever drives the environment before ``step`` is called.
    """

    def __init__(self, game, followers_list, leader, logger=None,
                 randomized=False, randomization_type="linear"):
        """Store the configuration and build the action/observation spaces.

        Args:
            game: Game object exposing ``list_of_agents``, ``action_space``
                and ``payoff``.
            followers_list: Names of the follower agents.
            leader: Name of the leader agent.
            logger: Optional logger (unused by this class's ``log_info``).
            randomized: When equal to True the leader plays a mixed action.
            randomization_type: Stored as-is; not read inside this class.
        """
        self.randomized = randomized
        self.randomization_type = randomization_type
        self.game = game
        self.leader = leader
        self.logger = logger
        self.followers_list = followers_list
        self.num_followers = len(self.followers_list)

        self.observation_space = Dict({'base_environment': MultiDiscrete([1, 2])})

        # The leader's action is a matrix action: one weight per action when
        # randomized, a single action index otherwise.
        num_leader_actions = self.game.action_space(leader)
        if self.randomized == True:
            self.action_space = Box(
                low=np.array([0] * num_leader_actions),
                high=np.array([10] * num_leader_actions),
                dtype=np.float32,
            )
        else:
            self.action_space = Discrete(num_leader_actions)

        # Followers are handled one by one rather than as a MARL environment.
        self.followers_observation_space = {}
        self.followers_action_space = {}
        for follower in self.followers_list:
            self.followers_observation_space[follower] = range(1)
            self.followers_action_space[follower] = range(self.game.action_space(follower))

    def reset(self):
        """Start a new episode and return the dummy observation."""
        self.get_followers_action = True
        self.followers_obs = dict.fromkeys(self.followers_list, 0)
        return OrderedDict({"base_environment": [0, 0]})

    def step(self, action):
        """Play the single round of the game.

        Returns:
            ``(observation, leader_utility, True, {"utilities": ...})``.
        """
        self.get_followers_action = False

        if self.randomized == True:
            resolve = self.randomized_step
        else:
            resolve = self.deterministic_step
        utilities = resolve(action, self.followers_actions)

        dummy_obs = OrderedDict({"base_environment": [0, 0]})
        return dummy_obs, utilities[self.leader], True, {"utilities": utilities}

    def randomized_step(self, leader_action, followers_actions):
        """Return every agent's expected utility under a mixed leader action.

        Args:
            leader_action: Weights over the leader's actions.
            followers_actions: Mapping follower name -> action.
        """
        probabilities = utils.weights_to_probs(leader_action)
        agents = self.game.list_of_agents
        expected = {agent: 0 for agent in agents}

        for pure_action in range(self.game.action_space(self.leader)):
            outcome = self.deterministic_step(pure_action, followers_actions)
            for agent in agents:
                expected[agent] = expected[agent] + probabilities[pure_action] * outcome[agent]
        return expected

    def deterministic_step(self, leader_action, followers_actions):
        """Return every agent's payoff for a pure leader action.

        ``followers_actions`` is copied, so the caller's mapping is left
        untouched.
        """
        joint_action = followers_actions.copy()
        joint_action[self.leader] = leader_action
        return {
            agent: self.game.payoff(joint_action, agent)
            for agent in self.game.list_of_agents
        }

    def log_info(self, info):
        """Nothing is logged for this environment."""
        return None



class BaseSimpleMatrixBayesian(gym.Env):
    """Two-step Bayesian matrix game with a message round.

    At ``reset`` one of ``games`` is sampled uniformly at random; its index
    is the type observed by every follower. In the first step the leader's
    action is ignored and the followers' messages are revealed to the
    leader. In the second step the leader picks a matrix action and the
    first follower best-responds in the sampled game.

    The followers' messages must be assigned to ``followers_actions`` by
    whoever drives the environment before the first ``step``.
    """

    def __init__(self, games, followers_list, leader, num_messages, logger=None):
        """Store the configuration and build the action/observation spaces.

        Args:
            games: List of game objects (one per type).
            followers_list: Names of the follower agents.
            leader: Name of the leader agent.
            num_messages: Number of messages available to each follower.
            logger: Optional object with a ``record(key, value)`` method.
        """
        self.games = games
        self.leader = leader
        self.logger = logger
        self.num_messages = num_messages
        self.followers_list = followers_list
        self.num_followers = len(self.followers_list)

        # In the message step the observation holds the followers' messages
        # followed by a flag set to 1 (observation is meaningful); all other
        # observations are the dummy pair [0, 0].
        self.observation_space = Dict({
            'base_environment': MultiDiscrete([len(games), 2]),
        })

        # Matrix action of the leader; all games are assumed to give the
        # leader the same number of actions.
        self.action_space = Discrete(self.games[0].action_space(leader))

        # Followers are handled one by one rather than as a MARL environment.
        self.followers_observation_space = {}
        self.followers_action_space = {}
        for follower in self.followers_list:
            self.followers_observation_space[follower] = range(len(games))
            self.followers_action_space[follower] = range(num_messages)
        self.freeze_types = False

    def reset(self):
        """Sample a game (unless types are frozen) and return a dummy observation."""
        # A new game is drawn uniformly unless a previous draw exists and
        # the types have been frozen.
        if not (hasattr(self, "type_idxs") and self.freeze_types):
            num_games = len(self.games)
            self.type_idxs = random.choices(range(num_games), [1] * num_games)[0]

        # Followers must be queried before the next step; each one observes
        # the sampled game index.
        self.get_followers_action = True
        self.followers_obs = dict.fromkeys(self.followers_list, self.type_idxs)

        # 1 = message step, >1 = matrix-action step.
        self.stepmode = 1

        return OrderedDict({"base_environment": [0, 0]})

    def step(self, action):
        """Advance the episode by one step.

        Message step: returns the followers' messages plus a flag of 1 as
        the observation, reward 0 and ``done=False``.
        Matrix step: returns the dummy observation, the leader's utility and
        ``done=True`` with ``info["utilities"]`` for every game agent.
        """
        # Message step: the leader only gets to see the followers' messages.
        if self.stepmode == 1:
            self.stepmode += 1
            self.get_followers_action = False

            messages = [self.followers_actions[follower] for follower in self.followers_list]
            observation = OrderedDict({"base_environment": messages + [1]})
            return observation, 0, False, {}

        # Matrix step: the leader selects a row and the follower best responds.
        if self.stepmode > 1:
            game = self.games[self.type_idxs]
            responder = self.followers_list[0]

            # Strict '>' keeps the lowest-indexed action among ties.
            best_utility = -float("inf")
            for response in range(game.action_space(game.agent_name_to_idx(responder))):
                candidate = {self.leader: action, responder: response}
                candidate_utility = game.payoff(candidate, responder)
                if candidate_utility > best_utility:
                    best_utility, best_profile = candidate_utility, candidate

            utilities = {agent: game.payoff(best_profile, agent) for agent in game.list_of_agents}
            dummy_obs = OrderedDict({"base_environment": [0, 0]})
            return dummy_obs, utilities[self.leader], True, {"utilities": utilities}

    def log_info(self, info):
        """Record the index of the sampled game with the logger."""
        self.logger.record("sampled_game", self.type_idxs)


# Right now we only do Prophet Inequality!
class BaseMessageSPM(gym.Env):
    """Sequential posted-price mechanism (SPM) preceded by a message round.

    Episode structure (tracked by ``stepmode``):

    1. Message step: the followers' messages ("bids"), which must have been
       assigned to ``followers_actions`` by whoever drives the environment,
       are written into the state so the leader can observe them. The
       leader's action is ignored and the reward is 0.
    2. SPM steps: each leader action holds one score per follower followed
       by the price entries. The not-yet-visited follower with the highest
       score is offered the prices and buys the available item that
       maximises its utility, provided that utility is non-negative. The
       episode ends when no follower or no item is left.

    The leader's terminal reward is the realised welfare minus the maximum
    achievable welfare.

    State layout (a flat float array):
        [agents left | units left per item | allocation block |
         price block | bids | dummy-observation flag]

    The type table, item counts and agent names are currently hard-coded
    for the Prophet Inequality setting.
    """

    def __init__(
            self,
            num_messages,
            discrete_prices=False,
            logger=None,
    ):
        """Build the hard-coded setting and the gym spaces.

        Args:
            num_messages: Number of messages available to each follower.
            discrete_prices: When true, the leader outputs one score per
                entry of the discrete price grid and the arg-max entry is
                used as the price; otherwise the price is read directly.
            logger: Optional object with a ``record(key, value)`` method.
        """
        # TODO: The following is only for PI. Make more general!
        self.EPSILON = 0.2
        self.types_table = [[0.5 / (1/(2*self.EPSILON)), 1], [0, 1 / (1/(2*self.EPSILON))]]
        self.num_diff_items = 1
        self.units_per_item = [1]
        self.followers_list = ["agent_0", "agent_1"]
        self.leader = "spm_designer"

        self.num_followers = len(self.followers_list)
        self.outcome = {}
        self.count_episodes = 0
        self.logger = logger
        self.discrete_prices = discrete_prices

        state_size = (
            self.num_followers +  # Agents left
            len(self.units_per_item) +  # Items left
            2 * self.num_diff_items * self.num_followers +  # Prices and allocations (hence 2*)
            self.num_followers +  # Bids
            1  # Dummy obs flag
        )
        self.observation_space = Dict({
            'base_environment': Box(
                low=np.array([0] * state_size),
                high=np.array(
                    [1] * self.num_followers +
                    self.units_per_item +
                    [1] * (2 * self.num_diff_items * self.num_followers) +
                    [1] * self.num_followers +  # Bids
                    [1]  # Dummy obs flag
                )
            )
        })

        # Leader action: one score per follower, then the price entries.
        if discrete_prices:
            self.discrete_price_vec = self.get_discrete_price_vec()
            action_size = self.num_followers + self.num_diff_items * len(self.discrete_price_vec)
        else:
            action_size = self.num_followers + self.num_diff_items
        self.action_space = Box(
            low=np.array([0] * action_size),
            high=np.array([1] * action_size),
            dtype=np.float32,
        )

        # TODO: the following line only works when the followers have the same number of types
        self.followers_observation_space = {
            follower: range(len(self.types_table[0])) for follower in self.followers_list
        }
        self.followers_action_space = {
            follower: range(num_messages) for follower in self.followers_list
        }
        self.freeze_types = False

    def get_discrete_price_vec(self):
        """Return the discrete price grid derived from the type table.

        The grid is 0 followed by every type value shifted up by half of the
        smallest gap between consecutive sorted type values.
        """
        all_types = np.sort(np.concatenate(
            [self.types_table[i] for i in range(self.num_followers)]).ravel())

        # Minimum gap between consecutive types
        delta_min = np.min(np.diff(all_types))

        return np.concatenate(([0], all_types + delta_min / 2))

    def reset(self):
        """Sample types (unless frozen), clear the records, return the state."""
        # Sample types
        if not hasattr(self, "type_idxs") or not self.freeze_types:
            self.type_idxs = [
                random.choices(range(len(self.types_table[0])), [1-self.EPSILON, self.EPSILON])[0],
                random.choices(range(len(self.types_table[1])), [0.5, 0.5])[0],
            ]
            self.valuations = [[self.types_table[i][self.type_idxs[i]]] for i in range(self.num_followers)]

        # Followers must be queried before the next step; each one observes
        # the index of its own type.
        self.get_followers_action = True
        self.followers_obs = {self.followers_list[i]: self.type_idxs[i] for i in range(self.num_followers)}

        # 1 = message step, >1 = SPM steps.
        self.stepmode = 1

        # Reset records
        self.overall_value = 0
        self.num_agents_left = self.num_followers
        self.num_items_left = np.sum(self.units_per_item)
        self.outcome['order'] = []
        self.outcome['prices'] = []
        self.outcome['mechanism_outcome'] = {}
        self.utilities = {follower: 0 for follower in self.followers_list}

        # Initial (dummy) state
        self.state = np.concatenate((
            np.ones(self.num_followers),  # Agents left
            np.asarray(self.units_per_item),  # Items left
            np.ones(self.num_diff_items * self.num_followers),  # Allocation matrix (subtract when you allocate)
            np.zeros(self.num_diff_items * self.num_followers),  # Price block
            np.zeros(self.num_followers + 1),  # Bids and dummy obs flag
        ))

        observation = OrderedDict({"base_environment": copy.deepcopy(self.state)})
        return observation

    def step(self, action):
        """Advance the episode by one step.

        Args:
            action: Leader action (ignored in the message step): one score
                per follower followed by the price entries.

        Returns:
            ``(observation, reward, done, info)``. ``reward`` is 0 until the
            episode ends; at that point ``info["utilities"]`` maps every
            follower and the leader to its utility.
        """
        # In the first step, we let the leader observe the followers' messages
        if self.stepmode == 1:
            self.stepmode += 1
            self.get_followers_action = False

            # Add bids to state and mark the observation as not dummy
            self.bids = [self.followers_actions[follower] for follower in self.followers_list]
            self.state[-(self.num_followers + 1):] = self.bids + [1]

            observation = OrderedDict({"base_environment": copy.deepcopy(self.state)})
            return observation, 0, False, {}

        # Then, we run SPM
        if self.stepmode > 1:
            # Offsets of the blocks inside the flat state vector
            items_start = self.num_followers
            alloc_start = items_start + len(self.units_per_item)
            price_start = alloc_start + self.num_diff_items * self.num_followers

            # Followers that were already visited can never be selected again
            agent_scores = action[:self.num_followers]
            agent_scores = [agent_scores[i] if self.state[i] else -math.inf for i in range(self.num_followers)]

            # np.asarray keeps `.tolist()` below working for plain sequences
            prices = np.asarray(action[self.num_followers:])
            if self.discrete_prices:
                grid_size = len(self.discrete_price_vec)
                prices = np.array([
                    self.discrete_price_vec[np.argmax(prices[i * grid_size:(i + 1) * grid_size])]
                    for i in range(self.num_diff_items)
                ])

            agent_idx = np.argmax(agent_scores)
            item_idx = self.buyer(agent_idx, prices)

            self.outcome['order'].append(agent_idx)
            self.outcome['prices'].append(prices.tolist())  # TODO: This doesn't work for multiple heterogeneous items

            # Remove agent from state
            self.state[agent_idx] = 0
            self.num_agents_left -= 1
            reward = 0

            # If agent buys, update allocation and items_left in state.
            # Compare by value: `buyer` may return a NumPy integer, for which
            # an identity test against -1 is not meaningful.
            # TODO: This only works for 1 item!
            if item_idx != -1:
                self.num_items_left -= 1
                self.overall_value += self.valuations[agent_idx][item_idx]

                # Remove item from state
                self.state[items_start + item_idx] -= 1

                # Update allocation matrix
                self.state[alloc_start + self.num_diff_items * agent_idx + item_idx] -= 1

                # Add agent to outcome
                self.outcome['mechanism_outcome']['agent_%i' % (agent_idx)] = {'allocation': item_idx, 'payment': prices[item_idx]}

            # Update the visited agent's slice of the price block: 1 for
            # items with no units left, 0 otherwise
            available_units_per_item = self.state[items_start:alloc_start]
            prices_available = [0 if available_units_per_item[i] else 1 for i in range(self.num_diff_items)]
            slice_start = price_start + self.num_diff_items * agent_idx
            self.state[slice_start:slice_start + self.num_diff_items] = prices_available

            observation = OrderedDict({"base_environment": copy.deepcopy(self.state)})
            done = self.num_agents_left <= 0 or self.num_items_left <= 0
            info = {}

            if done:
                self.count_episodes += 1

                # TODO: get max social welfare via setting-specific WD
                # For each item, the best welfare gives its units to the
                # followers with the highest valuations for that item.
                self.max_social_welfare = 0
                for j in range(self.num_diff_items):
                    sorted_vals = sorted((self.valuations[i][j] for i in range(len(self.valuations))), reverse=True)
                    self.max_social_welfare += sum(sorted_vals[:self.units_per_item[j]])

                # Sanity check: realised welfare can never beat the optimum
                if self.overall_value > self.max_social_welfare + 0.0000000001:
                    print("Oops")

                reward = self.overall_value - self.max_social_welfare
                self.utilities[self.leader] = reward
                info["utilities"] = self.utilities

            return observation, reward, done, info

    def log_info(self, info):
        """Record episode statistics (efficiency, order, prices, bids, values)."""
        self.logger.record("count", self.count_episodes)
        if self.max_social_welfare == 0:
            allocative_efficiency = 1.0
        else:
            allocative_efficiency = self.overall_value / self.max_social_welfare
        self.logger.record("efficiency", "%.5f" % allocative_efficiency)
        self.logger.record("overall_value", "%.5f" % self.overall_value)
        self.logger.record("opt", "%.5f" % self.max_social_welfare)

        for i, visited_agent in enumerate(self.outcome['order']):
            self.logger.record("order_%i" % i, visited_agent)
            # The next line only works in settings with 1 item!
            self.logger.record("price_%i" % i, self.outcome['prices'][i][0])

        for i in range(self.num_followers):
            self.logger.record("bids_agent_%i" % i, self.bids[i])
            # The next line only works in settings with 1 item!
            self.logger.record("value_agent_%i" % i, self.valuations[i][0])

    def buyer(self, agent, prices):
        """Return the item bought by ``agent`` at ``prices``, or -1 for none.

        The agent picks the available item with the highest utility
        (valuation minus price) and buys it only when that utility is
        non-negative. The agent's entry in ``self.utilities`` is set to that
        utility, floored at 0.

        Args:
            agent: Index of the follower being offered the prices.
            prices: One price per item type.
        """
        valuation = self.valuations[agent]
        available_units_per_item = self.state[self.num_followers:self.num_followers + len(self.units_per_item)]
        # Sold-out items get utility -inf so they can never be chosen
        utility = [valuation[i] - prices[i] if available_units_per_item[i] else -math.inf
                   for i in range(self.num_diff_items)]
        choice = np.argmax(utility)
        self.utilities[self.followers_list[agent]] = max(utility[choice], 0)
        return choice if utility[choice] >= 0 else -1



class RLSupervisorQFollowersWrapper(gym.Wrapper):
    """Wrapper in which the followers act as independent Q-learners.

    Before every sub-environment step that requires it, the followers'
    actions are computed from their Q-tables and handed to the wrapped
    environment. In "equilibrium" mode the followers explore with
    probability ``exp(-beta * step_counter)`` and their Q-tables are updated
    whenever a sub-episode ends; in "reward" mode they act greedily and the
    tables are left untouched.

    ``step`` always reports ``done=False``: when the wrapped environment
    finishes a sub-episode it is reset on the following call.
    """

    def __init__(
            self,
            env,
            alpha=0.15,
            delta=0.95,
            beta=0.00001,
    ):
        """
        Args:
            env: Environment exposing ``followers_list``, the followers'
                observation/action spaces and ``get_followers_action``.
            alpha: Learning rate of the Q-update.
            delta: Multiplier of the continuation term of the Q-update.
            beta: Decay rate of the exploration probability.
        """
        super(RLSupervisorQFollowersWrapper, self).__init__(env)

        self.alpha = alpha
        self.beta = beta
        self.delta = delta
        # Equilibrium mode is the default.
        self.this_step_mode = "equilibrium"

        # Needed to initialize observation space in StackPOMDP wrapper
        self.q_init()

    def reset(self):
        """Re-initialise the Q-tables and counters and reset the sub-env."""
        self.q_init()
        self.step_counter = 0
        first_obs = self.env.reset()
        self.sub_env_done = False
        return first_obs

    def step(self, action):
        """Advance the wrapped environment, restarting it when it has ended.

        Returns:
            ``(obs, reward, False, info)``. When a sub-episode ends, ``info``
            additionally carries ``"followers_actions"`` and
            ``"sub_env_done"``.
        """
        self.step_counter += 1

        # The previous call ended a sub-episode: this call only restarts it.
        if self.sub_env_done:
            fresh_obs = self.env.reset()
            self.sub_env_done = False
            return fresh_obs, 0, False, {}

        if self.env.get_followers_action:
            self.env.followers_actions = self.get_followers_actions(
                self.env.followers_obs, self.this_step_mode)

        obs, reward, done, info = self.env.step(action)
        if not done:
            return obs, reward, False, info

        self.sub_env_done = True

        # Record for logging
        info["followers_actions"] = self.followers_actions
        info["sub_env_done"] = True

        # Followers only learn during equilibrium steps
        if self.this_step_mode == 'equilibrium':
            self.update_q_tables("", self.followers_actions, info["utilities"], self.env.followers_obs)

        return obs, reward, False, info

    def q_matrices_to_norm_vec(self):
        """Return all Q-values as one flat vector.

        Each follower's table is flattened and, when it has a non-zero
        entry, divided by its largest absolute value.
        """
        stacked = np.empty((0))
        for table in self.q_tables.values():
            flat = np.array(list(table.values())).flatten()
            peak = max(abs(flat))
            if peak > 0:
                flat = flat / peak
            stacked = np.append(stacked, flat)
        return stacked

    def q_init(self):
        """Create a zero-initialised Q-table for every follower.

        Tables are keyed by ``str(observation)`` and hold one value per
        action.
        """
        self.q_tables = {}
        for agent in self.followers_list:
            num_actions = len(self.followers_action_space[agent])
            self.q_tables[agent] = {
                str(observation): [0] * num_actions
                for observation in self.followers_observation_space[agent]
            }

    def get_followers_actions(self, observation, action_type="equilibrium"):
        """Return a dict mapping every follower to its chosen action.

        Args:
            observation: Mapping follower name -> observation.
            action_type: "reward" for greedy actions, "equilibrium" for
                epsilon-greedy actions. Any other value yields an empty
                dict.
        """
        if action_type == "reward":
            return {
                agent: np.argmax(self.q_tables[agent][str(observation[agent])])
                for agent in self.followers_list
            }

        chosen = {}
        if action_type == "equilibrium":
            # Exploration probability decays with the number of steps taken
            explore_prob = np.exp(-1 * self.beta * self.step_counter)
            for agent in self.followers_list:
                q_values = self.q_tables[agent][str(observation[agent])]
                if random.uniform(0, 1) < explore_prob:
                    chosen[agent] = random.randint(0, len(q_values) - 1)
                else:
                    chosen[agent] = np.argmax(q_values)
        return chosen

    def update_q_tables(self, next_observation, actions_dict, reward, prev_observation):
        """Apply one Q-learning update per follower.

        The continuation value is fixed at 0, so ``next_observation`` is not
        used and each entry moves towards the received reward.

        Args:
            next_observation: Unused.
            actions_dict: Mapping follower name -> action that was played.
            reward: Mapping follower name -> reward received.
            prev_observation: Mapping follower name -> observation acted on.
        """
        continuation_value = 0
        for agent in self.followers_list:
            q_row = self.q_tables[agent][str(prev_observation[agent])]
            played = actions_dict[agent]
            previous = q_row[played]
            q_row[played] = \
                ((1 - self.alpha) * previous) + (self.alpha * (reward[agent] + self.delta * continuation_value))

    def log_info(self, info):
        """Record the Q-tables, then let the wrapped environment log."""
        self.logger.record("q_tables", str(self.q_tables))
        self.env.log_info(info)



class RLSupervisorMWFollowersWrapper(gym.Wrapper):
    """Wrapper in which the followers learn through multiplicative weights.

    In "equilibrium" mode an action profile is sampled from the followers'
    weights and then, with the types of the wrapped environment frozen,
    every (follower, action) unilateral deviation from that profile is
    played in its own sub-episode. The utility of each deviation is stored
    and, once all deviations were tried, every weight of the observed type
    is multiplied by ``(1 + epsilon) ** utility``.

    In "reward" mode every follower plays the arg-max of its weights and no
    weights are updated.
    """
    CLIP_MIN = 0.001
    CLIP_ITERATIONS = 0
    DEFAULT_EPS = 0.01

    def __init__(
            self,
            env,
            epsilon=DEFAULT_EPS,
            clip_min=CLIP_MIN,
            clip_iterations=CLIP_ITERATIONS,
    ):
        """
        Args:
            env: Environment exposing ``followers_list``, the followers'
                observation/action spaces, ``get_followers_action`` and
                ``freeze_types``.
            epsilon: Learning rate of the multiplicative-weights update.
            clip_min: Lower bound applied to normalised weights when
                clipping.
            clip_iterations: Clipping period in wrapper steps; 0 disables
                clipping.
        """
        super(RLSupervisorMWFollowersWrapper, self).__init__(env)

        self.epsilon = epsilon
        self.clip_min = clip_min
        self.clip_iterations = clip_iterations

        # utilities_table[follower index][observation][action]
        self.utilities_table = [
            [[0] * len(self.followers_action_space[follower])
             for _ in self.followers_observation_space[follower]]
            for follower in self.followers_list
        ]

        self.step_counter = 0
        # Equilibrium mode is the default.
        self.this_step_mode = "equilibrium"

        # Needed to initialize observation space in StackPOMDP wrapper
        self.weights = self._initial_weights()

    def _initial_weights(self):
        """Return one all-ones (observations x actions) matrix per follower."""
        return [
            np.ones(
                (len(self.followers_observation_space[name]), len(self.followers_action_space[name])),
                dtype=np.float64,
            )
            for name in self.followers_list
        ]

    def reset(self):
        """Reset weights and counterfactual bookkeeping, then the sub-env."""
        self.weights = self._initial_weights()

        # Next counterfactual (follower, action) pair to try
        self.follower_idx = 0
        self.action_idx = 0

        # True once all counterfactual actions were tried; triggers the
        # sampling of a new types/action profile
        self.mw_iteration_done = True

        first_obs = self.env.reset()

        # Types stay frozen while counterfactual actions are being tried
        self.env.freeze_types = self.this_step_mode == "equilibrium"
        self.sub_env_done = False

        return first_obs

    def step(self, action):
        """Advance the wrapped environment, restarting it when it has ended.

        Returns:
            The wrapped environment's ``(obs, reward, done, info)``; when a
            sub-episode ends, ``info`` additionally carries
            ``"followers_actions"`` and ``"sub_env_done"``. The call that
            restarts the sub-environment returns ``(obs, 0, False, {})``.
        """
        self.step_counter += 1

        # The previous call ended a sub-episode: this call only restarts it.
        if self.sub_env_done:
            fresh_obs = self.env.reset()
            self.sub_env_done = False

            # Types stay frozen while counterfactual actions are being tried
            self.env.freeze_types = self.this_step_mode == "equilibrium"
            return fresh_obs, 0, False, {}

        if self.env.get_followers_action:
            self.env.followers_actions = self.get_followers_actions(self.env.followers_obs)

        obs, reward, done, info = self.env.step(action)
        if not done:
            return obs, reward, done, info

        self.sub_env_done = True

        # Record for logging
        info["followers_actions"] = self.followers_actions
        info["sub_env_done"] = True

        in_reward_mode = self.this_step_mode == "reward"

        # Keep the utility produced by the counterfactual action so that the
        # weights can be updated later
        if not in_reward_mode:
            deviator = self.followers_list[self.follower_idx]
            deviator_obs = self.env.followers_obs[deviator]
            self.utilities_table[self.follower_idx][deviator_obs][self.action_idx] = \
                info['utilities'][deviator]

        # Move on to the next counterfactual action
        self.update_idxs()

        if self.mw_iteration_done and self.this_step_mode == "equilibrium":
            self.update_weights()

        # New types are sampled at the next sub-env reset, even in
        # equilibrium steps
        if self.mw_iteration_done or in_reward_mode:
            self.env.freeze_types = False

        return obs, reward, done, info

    def update_idxs(self):
        """Advance to the next counterfactual (follower, action) pair.

        When the last action of the last follower has been tried, the
        indices wrap to (0, 0) and ``mw_iteration_done`` is set.
        """
        num_actions = len(self.followers_action_space[self.followers_list[self.follower_idx]])
        last_follower = self.num_followers - 1

        if self.action_idx < num_actions:
            self.action_idx += 1
        if self.action_idx == num_actions and self.follower_idx < last_follower:
            self.follower_idx, self.action_idx = self.follower_idx + 1, 0
        if self.action_idx == num_actions and self.follower_idx == last_follower:
            self.mw_iteration_done = True
            self.follower_idx, self.action_idx = 0, 0

    def get_followers_actions(self, types):
        """Return the action profile the followers play in the next sub-step.

        Args:
            types: Mapping follower name -> observed type index.

        Returns:
            In reward mode, the arg-max profile. Otherwise the current
            sampled profile with the follower under examination switched to
            its current counterfactual action.
        """
        # Reward mode: everybody plays its highest-weight action
        if self.this_step_mode == "reward":
            self.current_actions = {
                name: np.argmax(self.weights[idx][types[name]])
                for idx, name in enumerate(self.followers_list)
            }
            return self.current_actions

        # Sample a new base profile once all counterfactuals were tried
        if self.mw_iteration_done:
            self.current_actions = {
                name: random.choices(self.followers_action_space[name],
                                     self.weights[idx][types[name]])[0]
                for idx, name in enumerate(self.followers_list)
            }
            self.mw_iteration_done = False

        deviation = self.current_actions.copy()
        deviation[self.followers_list[self.follower_idx]] = self.action_idx
        return deviation

    def weights_to_norm_vec(self):
        """Return all weights as one flat vector.

        Each follower's matrix is flattened and, when it has a non-zero
        entry, divided by its largest absolute value.
        """
        stacked = np.empty((0))
        for idx, _name in enumerate(self.followers_list):
            flat = self.weights[idx].flatten()
            peak = max(abs(flat))
            if peak > 0:
                flat = flat / peak
            stacked = np.append(stacked, flat)
        return stacked

    def update_weights(self):
        """Apply the multiplicative-weights update for the observed types.

        Optionally (every ``clip_iterations`` wrapper steps) the updated
        rows are divided by their maximum and floored at ``clip_min``.
        """
        # Multiplicative update of the row of each follower's observed type
        for agent_idx, agent in enumerate(self.followers_list):
            observed = self.env.followers_obs[agent]
            weight_row = self.weights[agent_idx][observed]
            payoffs = self.utilities_table[agent_idx][observed]
            for action in self.followers_action_space[agent]:
                weight_row[action] *= (1 + self.epsilon) ** payoffs[action]

        # Clip weights after self.clip_iterations iterations
        if self.clip_iterations > 0 and ((self.step_counter - self.clip_iterations + 1) % self.clip_iterations == 0):
            for agent_idx, agent in enumerate(self.followers_list):
                weight_row = self.weights[agent_idx][self.env.followers_obs[agent]]
                max_weight = max(weight_row)
                for action in self.followers_action_space[agent]:
                    weight_row[action] = max(weight_row[action] / max_weight, self.clip_min)

    def log_info(self, info):
        """Record the weights, then let the wrapped environment log."""
        self.logger.record("weights", str(self.weights))
        self.env.log_info(info)



class StackMDPWrapper(gym.Wrapper):

    def __init__(
            self,
            env,
            tot_num_eq_steps=1000,
            tot_num_reward_steps=10,
            frac_excluded_eq_steps=0,
            critic_obs="flag",
    ):
        """Wrap `env` and declare the observation space of the StackMDP.

        Args:
            env: environment to wrap; its observation space must expose a
                'base_environment' entry when `critic_obs` is "flag" or "full".
            tot_num_eq_steps: number of equilibrium steps step() runs per episode.
            tot_num_reward_steps: number of reward steps step() runs per episode.
            frac_excluded_eq_steps: fraction of the equilibrium steps that
                reset() selects to be flagged with info["exclude_from_buffer"].
            critic_obs: "flag" adds 'critic:is_reward_step' to the observation
                space; "full" additionally adds either the exploration-rate and
                Q-matrix entries (when a `q_tables` attribute is reachable) or
                the weights entry (when a `weights` attribute is reachable).
                Any other value leaves `observation_space` untouched here.
        """
        super(StackMDPWrapper, self).__init__(env)

        # This sets the total number of equilibrium and reward steps in StackMDP
        self.tot_num_eq_steps = tot_num_eq_steps
        self.tot_num_reward_steps = tot_num_reward_steps
        self.critic_obs = critic_obs

        self.tot_num_steps = 0

        self.frac_excluded_eq_steps = frac_excluded_eq_steps

        # Set up observation space:
        #    entry 'base_environment' contains the part of observation for which action may be fixed during a StackMDP episode
        #    whatever starts with 'critic:' will only be seen by critic network
        if self.critic_obs == "flag":
            flag_layout = {
                'base_environment': self.env.observation_space['base_environment'],
                'critic:is_reward_step': Discrete(2),
            }
            self.observation_space = Dict(flag_layout)
            return

        if self.critic_obs != "full":
            return

        if hasattr(self, "q_tables"):
            # One slot per Q-value across all followers' tables.
            num_q_entries = sum(
                np.array(list(self.q_tables[follower].values())).size
                for follower in self.q_tables.keys()
            )
            q_layout = {
                'base_environment': self.env.observation_space['base_environment'],
                'critic:is_reward_step': Discrete(2),
                'critic:exploration_rates': Box(low=0, high=1.0, shape=(len(self.env.followers_list),)),
                'critic:Q_matrices': Box(low=-1.0, high=1.0, shape=(num_q_entries,)),
            }
            self.observation_space = Dict(q_layout)
        elif hasattr(self, "weights"):
            # One slot per weight across all followers.
            num_weights = sum(
                len(self.weights[idx].flatten())
                for idx in range(len(self.env.followers_list))
            )
            weights_layout = {
                'base_environment': self.env.observation_space['base_environment'],
                'critic:is_reward_step': Discrete(2),
                'critic:weights': Box(low=-1.0, high=1.0, shape=(num_weights,)),
            }
            self.observation_space = Dict(weights_layout)


    def reset(self):
        """Begin a new StackMDP episode and return its first observation.

        Switches the wrapped env to "equilibrium" mode, resets it, zeroes both
        step counters and draws the equilibrium steps that step() will flag
        with ``info["exclude_from_buffer"] = True``.

        Returns:
            OrderedDict holding the wrapped env's 'base_environment' entry plus
            whatever augment_observation adds for the configured `critic_obs`.
        """
        sub_env = self.env
        sub_env.this_step_mode = "equilibrium"
        sub_env_obs = sub_env.reset()  # Restart sub_env

        self.eq_steps_counter = 0
        self.reward_steps_counter = 0

        # The following lines are needed to exclude steps from replay buffers:
        # draw distinct 1-based equilibrium step indices and keep them sorted,
        # since step() only ever compares against the first remaining index.
        num_excluded = int(self.frac_excluded_eq_steps * self.tot_num_eq_steps)
        candidate_steps = range(1, self.tot_num_eq_steps + 1)
        self.excluded_indexes = sorted(random.sample(candidate_steps, k=num_excluded))

        first_observation = OrderedDict({"base_environment": sub_env_obs["base_environment"]})
        self.augment_observation(first_observation)
        return first_observation


    def augment_observation(self, observation, is_reward_step=0):
        """Add the critic-only entries to `observation` in place.

        Args:
            observation: mapping that receives the extra 'critic:*' keys.
            is_reward_step: value stored under 'critic:is_reward_step'.

        With `critic_obs` equal to "flag" only the reward-step flag is added.
        With "full", the flag is added together with either the normalised
        Q-matrices and one exploration rate per follower (when a `q_tables`
        attribute is reachable) or the normalised weights (when a `weights`
        attribute is reachable). Any other `critic_obs` leaves `observation`
        unchanged.
        """
        mode = self.critic_obs
        if mode not in ("flag", "full"):
            return

        observation["critic:is_reward_step"] = is_reward_step
        if mode == "flag":
            return

        if hasattr(self, "q_tables"):
            observation["critic:Q_matrices"] = self.q_matrices_to_norm_vec()
            # The same rate, exp(-beta * step_counter), is repeated once per follower.
            rates = [
                np.exp(-1 * self.env.beta * self.env.step_counter)
                for _ in self.env.followers_list
            ]
            observation['critic:exploration_rates'] = np.array(rates)
        elif hasattr(self, "weights"):
            observation["critic:weights"] = self.weights_to_norm_vec()


    def step(self, action):
        """Advance the StackMDP episode by one equilibrium or reward step.

        An episode consists of `tot_num_eq_steps` equilibrium steps followed by
        `tot_num_reward_steps` reward steps. Equilibrium steps always return
        reward 0 and done=False; reward steps return the wrapped env's reward,
        and done becomes True on the last reward step.

        Args:
            action: action forwarded unchanged to the wrapped env.

        Returns:
            Tuple ``(observation, reward, done, info)``. ``observation`` is an
            OrderedDict with the wrapped env's 'base_environment' entry plus
            the entries added by augment_observation. ``info`` is the wrapped
            env's info dict with an added "exclude_from_buffer" flag.

        Raises:
            RuntimeError: if called when neither equilibrium nor reward steps
                remain in the current episode (previously this surfaced as an
                UnboundLocalError from the final return statement).
        """
        # Only wrappers without an `is_eval` attribute count their steps here;
        # for the others the counter is advanced in log_info. The progress
        # message is tied to the increment so that a counter that does not move
        # (e.g. still 0) does not trigger a print on every single step.
        if not hasattr(self, "is_eval"):
            self.tot_num_steps += 1
            if self.tot_num_steps % 100000 == 0:
                print("StackMDPWrapper steps completed: ", self.tot_num_steps)

        if self.eq_steps_counter < self.tot_num_eq_steps:

            # We do an equilibrium step
            self.eq_steps_counter += 1

            # The last equilibrium step already runs the wrapped env in "reward" mode.
            is_last_eq_step = self.eq_steps_counter == self.tot_num_eq_steps
            self.env.this_step_mode = "reward" if is_last_eq_step else "equilibrium"

            obs, _, _, info = self.env.step(action)

            full_observation = OrderedDict({"base_environment": obs["base_environment"]})
            self.augment_observation(full_observation)

            # Checks if step should be excluded from buffer: excluded_indexes is
            # sorted, so only its first remaining entry can match this step.
            exclude = bool(self.excluded_indexes) and self.eq_steps_counter == self.excluded_indexes[0]
            if exclude:
                self.excluded_indexes.pop(0)
            info["exclude_from_buffer"] = exclude

            return full_observation, 0, False, info

        if self.reward_steps_counter >= self.tot_num_reward_steps:
            raise RuntimeError(
                "StackMDPWrapper.step() called with no equilibrium or reward steps "
                "left in this episode; call reset() before stepping again."
            )

        # We do a reward step
        self.reward_steps_counter += 1

        self.env.this_step_mode = "reward"
        obs, reward, _, info = self.env.step(action)

        info["exclude_from_buffer"] = False

        # We log only at the end of subepisodes
        if hasattr(self, "is_eval") and info.get("sub_env_done"):
            info["supervisor_action"] = action
            self.log_info(info)

        full_observation = OrderedDict({"base_environment": obs["base_environment"]})
        self.augment_observation(full_observation, is_reward_step=1)

        # The wrapped env's own done flag is ignored; the episode ends with the
        # last reward step.
        done = self.reward_steps_counter == self.tot_num_reward_steps

        return full_observation, reward, done, info


    def log_info(self, info):
        """Log one sub-episode's results and flush the logger.

        When an `is_eval` attribute is reachable the step counter is advanced
        here. `info` is first passed to the wrapped env's ``log_info``; then
        the step count, the supervisor action (only for the two simple matrix
        base environments), the leader's reward and each follower's reward and
        action are recorded, and the logger is dumped at ``tot_num_steps``.

        Args:
            info: dict with 'utilities' (keyed by agent) and
                'followers_actions' (keyed by follower); 'supervisor_action'
                is read only for the simple matrix base environments.
        """
        if hasattr(self, "is_eval"):
            self.tot_num_steps += 1

        self.env.log_info(info)
        self.logger.record("count_steps", self.tot_num_steps)

        base_env = self.unwrapped
        if isinstance(base_env, BaseEnvSimpleMatrixGame) or isinstance(base_env, BaseSimpleMatrixBayesian):
            self.logger.record("supervisor_action", info["supervisor_action"])
            # self.logger.record("obs_action_map", str(self.model.policy.obs_action_map))

        utilities = info['utilities']
        self.logger.record("leader_reward", utilities[self.leader])
        for follower in self.followers_list:
            reward_key = follower + "_reward"
            action_key = follower + "_action"
            self.logger.record(reward_key, info['utilities'][follower])
            self.logger.record(action_key, info['followers_actions'][follower])
        self.logger.dump(self.tot_num_steps)